import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import torch_npu
import torchvision.datasets as datasets
import torchvision.models as models
from torch_npu.npu import amp
from torch.utils.tensorboard import SummaryWriter
import datetime
import torchvision.transforms as transforms
import shutil

# Directory that receives the checkpoints, the labels file and the TensorBoard logs.
model_path = "models"
# Device that the model and every input batch are moved to.
device = torch.device('npu:0')
# TensorBoard writer; each launch logs into its own sub-directory named after the start timestamp.
tensorboard = SummaryWriter(log_dir=os.path.join(model_path, "tensorboard", f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"))
# Highest validation accuracy seen so far; main() updates it through `global`.
best_accuracy = 0


class AverageMeter(object):
    """
    Computes and stores the average and current value

    Attributes:
        name:  label printed in front of the values.
        fmt:   format spec (including the leading ':') applied to val and avg.
        val:   most recent value passed to update().
        sum:   weighted sum of all values since the last reset().
        count: total weight (sum of all `n`) since the last reset().
        avg:   sum / count, or 0 while count is still 0.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record `val` with weight `n` (typically the batch size) and refresh the running average.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight has no defined mean; report 0 rather than
        # raising ZeroDivisionError (e.g. update(x, 0) right after reset()).
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Builds e.g. '{name} {val:6.3f} ({avg:6.3f})' and fills it from the instance attributes.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """
    Prints one progress line made of a batch counter followed by a set of meters
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the prefix, the '[batch/total]' counter and every meter, separated by two spaces."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header, *map(str, self.meters)]
        print('  '.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as '[{:3d}/100]' whose slot is as wide as the total batch count."""
        width = len(str(num_batches // 1))
        slot = f"{{:{width}d}}"
        return f"[{slot}/{slot.format(num_batches)}]"

def accuracy(output, target):
    """
    Computes the accuracy of predictions vs groundtruth

    Args:
        output: per-class scores, with the class dimension last.
        target: ground-truth class indices that the arg-max of `output` is compared against.

    Returns:
        The percentage (0-100) of matching predictions as a Python float,
        or 0.0 when the batch is empty.
    """
    with torch.no_grad():
        # The mean over zero samples is NaN, which would poison any running
        # average built from this value; an empty batch scores 0.
        if target.numel() == 0:
            return 0.0

        # softmax is monotonic, so the arg-max of the raw scores already
        # identifies the predicted class; no need to normalise first.
        preds = torch.argmax(output, dim=-1)
        correct = (preds == target)

        return correct.float().mean().cpu().item() * 100.0
        
def train(train_loader, model, criterion, optimizer, scaler, epoch):
    """
    Run one training pass over `train_loader` and return (average loss, average accuracy).

    The forward pass runs inside `amp.autocast()`, and the backward pass and
    optimizer step go through `scaler`. Progress is printed every 50 batches
    and on the last batch; the epoch averages are also written to TensorBoard.
    """
    step_timer = AverageMeter('Time', ':6.3f')
    load_timer = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Accuracy', ':7.3f')

    total_batches = len(train_loader)
    reporter = ProgressMeter(
        total_batches,
        [step_timer, load_timer, loss_meter, acc_meter],
        prefix=f"Epoch: [{epoch}]")

    # put the model into training mode
    model.train()

    # reference timestamps: one for the whole epoch, one for the current batch
    epoch_start = time.time()
    tick = epoch_start

    for step, batch in enumerate(train_loader):
        images, target = batch

        # time spent waiting for the loader to deliver this batch
        load_timer.update(time.time() - tick)

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        batch_size = images.size(0)

        # forward pass
        with amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        # bookkeeping, weighted by the number of samples in the batch
        loss_meter.update(loss.item(), batch_size)
        acc_meter.update(accuracy(output, target), batch_size)

        # backward pass and parameter update through the gradient scaler
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # total time for this batch, then restart the clock
        step_timer.update(time.time() - tick)
        tick = time.time()

        is_last = step == total_batches - 1
        if is_last or step % 50 == 0:
            reporter.display(step)

    print(f"Epoch: [{epoch}] completed, elapsed time {time.time() - epoch_start:6.3f} seconds")

    tensorboard.add_scalar('Loss/train', loss_meter.avg, epoch)
    tensorboard.add_scalar('Accuracy/train', acc_meter.avg, epoch)
    return loss_meter.avg, acc_meter.avg

def validate(val_loader, model, criterion, epoch):
    """
    Evaluate the model on `val_loader` and return (average loss, average accuracy).

    Runs in eval mode with gradients disabled. Progress is printed every
    10 batches and on the last batch; the averages are also written to
    TensorBoard under the given epoch.
    """
    step_timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Accuracy', ':7.3f')

    total_batches = len(val_loader)
    reporter = ProgressMeter(
        total_batches,
        [step_timer, loss_meter, acc_meter],
        prefix='Val:   ')

    # put the model into evaluation mode
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(val_loader):
            images, target = batch
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            batch_size = images.size(0)

            # forward pass
            with amp.autocast():
                output = model(images)
                loss = criterion(output, target)

            # bookkeeping, weighted by the number of samples in the batch
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(accuracy(output, target), batch_size)

            # total time for this batch, then restart the clock
            step_timer.update(time.time() - tick)
            tick = time.time()

            is_last = step == total_batches - 1
            if is_last or step % 10 == 0:
                reporter.display(step)

    tensorboard.add_scalar('Loss/val', loss_meter.avg, epoch)
    tensorboard.add_scalar('Accuracy/val', acc_meter.avg, epoch)

    return loss_meter.avg, acc_meter.avg
    
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar', labels_filename='labels.txt'):
    """
    Save a model checkpoint file, along with the best-performing model if applicable

    Args:
        state: dict passed to torch.save(); must contain 'epoch', and also
            'classes' (an iterable of labels) when state['epoch'] == 0.
        is_best: when true, the freshly written checkpoint is also copied to `best_filename`.
        filename: checkpoint file name inside the model directory.
        best_filename: file name of the best-model copy inside the model directory.
        labels_filename: file name of the class-label list inside the model directory.
    """
    model_dir = os.path.expanduser(model_path)

    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.join(model_dir, filename)
    best_filename = os.path.join(model_dir, best_filename)
    labels_filename = os.path.join(model_dir, labels_filename)

    # save the checkpoint
    torch.save(state, filename)

    # earmark the best checkpoint
    if is_best:
        shutil.copyfile(filename, best_filename)
        print(f"saved best model to:  {best_filename}")
    else:
        print(f"saved checkpoint to:  {filename}")

    # save labels.txt on the first epoch
    if state['epoch'] == 0:
        # explicit UTF-8 so labels with non-ASCII characters do not depend on the locale encoding
        with open(labels_filename, 'w', encoding='utf-8') as file:
            file.writelines(f"{label}\n" for label in state['classes'])
        print(f"saved class labels to:  {labels_filename}")
            
    
def main():
    """
    Fine-tune a pretrained ResNet-18 on the image folders under ./dataset.

    Builds the train/val data loaders, replaces the final fully connected
    layer so its output size matches the number of training classes, then
    runs 10 epochs of train + validate. A checkpoint is saved after every
    epoch, and the module-level `best_accuracy` tracks the best validation
    accuracy. The TensorBoard writer is closed when the loop ends, whether
    normally or through an exception.
    """
    global best_accuracy
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder("./dataset/train", train_transforms)
    val_dataset = datasets.ImageFolder("./dataset/val", val_transforms)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=8, shuffle=True,
        num_workers=3, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=16, shuffle=False,
        num_workers=3, pin_memory=True)

    # start from the pretrained backbone and swap in a classifier head sized for this dataset
    model = models.resnet18(pretrained=True)
    num_classes = len(train_dataset.classes)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    lr = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    scaler = amp.GradScaler()
    epochs = 10

    try:
        for epoch in range(epochs):

            train_loss, train_acc = train(train_loader, model, criterion, optimizer, scaler, epoch)
            val_loss, val_acc = validate(val_loader, model, criterion, epoch)

            # remember best acc@1 and save checkpoint
            is_best = val_acc > best_accuracy
            best_accuracy = max(val_acc, best_accuracy)

            print(f"=> Epoch {epoch}")
            print(f"  * Train Loss     {train_loss:.4e}")
            print(f"  * Train Accuracy {train_acc:.4f}")
            print(f"  * Val Loss       {val_loss:.4e}")
            print(f"  * Val Accuracy   {val_acc:.4f}{'*' if is_best else ''}")

            save_checkpoint({
                'epoch': epoch,
                'arch': "resnet18",
                'resolution': 224,
                'classes': train_dataset.classes,
                'num_classes': num_classes,
                'multi_label': False,
                'state_dict': model.state_dict(),
                'accuracy': {'train': train_acc, 'val': val_acc},
                'loss' : {'train': train_loss, 'val': val_loss},
                'optimizer' : optimizer.state_dict(),
            }, is_best)
    finally:
        # Flush and close the event file so scalars still buffered by the
        # writer are written to disk even if training stops early.
        tensorboard.close()

# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
